import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import mpl_axes_aligner
import pandas as pd

# Compact typography sized for small multi-panel journal figures. Applied once
# at import time, so every figure created afterwards inherits these settings.
_RC_OVERRIDES = {
    "font.family": "DejaVu Serif",  # or "Times New Roman" for the final draft
    "font.size": 9,
    "axes.labelsize": 9,
    "legend.fontsize": 8,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
}
plt.rcParams.update(_RC_OVERRIDES)



def build_increase_dataframe(log_precip_increase, frac_change_qv):
    """
    Construct a tidy DataFrame summarizing per-model changes in precipitation (ΔlnP)
    and fractional change in near-surface specific humidity (Δlnqv) across transitions.

    Parameters:
        log_precip_increase (dict):
            Nested dictionary mapping:
                model_name -> {transition_key -> ΔlnP value}
            Example:
                {"SAM": {"295-300K": 0.10, "300-305K": 0.11}, ...}

        frac_change_qv (dict):
            Nested dictionary mapping:
                model_name -> {transition_key -> Δlnqv value}
            Example:
                {"SAM": {"295-300K": 0.08, "300-305K": 0.09}, ...}

    Returns:
        pd.DataFrame: One row per model, containing:
            - 'model' column
            - 'dellnP_{transition}' columns for each ΔlnP value
            - 'dellnqv_{transition}' columns for each Δlnqv value
        Rows appear in first-seen order: the models of ``log_precip_increase``
        first, then any additional models that only occur in ``frac_change_qv``.
        A model lacking a transition in one input gets NaN in that column.
        If both inputs are empty, an empty frame that still has a 'model'
        column is returned, so ``df["model"]`` lookups downstream keep working.
    """
    # Union of model names in a deterministic order. Iterating a plain set
    # would make the row order depend on string-hash randomization per run.
    all_models = dict.fromkeys([*log_precip_increase, *frac_change_qv])

    rows = []
    for model in all_models:
        row = {"model": model}

        # ΔlnP values, stored as dellnP_{transition}
        for transition, value in log_precip_increase.get(model, {}).items():
            row[f"dellnP_{transition}"] = value

        # Δlnqv values, stored as dellnqv_{transition}
        for transition, value in frac_change_qv.get(model, {}).items():
            row[f"dellnqv_{transition}"] = value

        rows.append(row)

    if not rows:
        # pd.DataFrame([]) would have no columns at all.
        return pd.DataFrame(columns=["model"])
    return pd.DataFrame(rows)

def panel_stair(ax, array2plot, boundaries, color, label, selected_labels, xlabel, pos, ylabel=None):
    """
    Draw one step-histogram curve on ``ax`` and style its log-scaled x-axis.

    Parameters:
        ax (matplotlib.axes.Axes): The subplot axis to draw on.
        array2plot (np.ndarray): Histogram values, one per bin.
        boundaries (np.ndarray): Bin edges; ``Axes.stairs`` requires
            ``len(boundaries) == len(array2plot) + 1``.
        color (str or tuple): Line color.
        label (str): Legend entry for this curve.
        selected_labels (list): X tick positions. They are also used verbatim
            as the tick labels, and their min/max define the x-limits.
        xlabel (str): X-axis label.
        pos (str): Legend position; unused here, kept so existing call sites
            keep working.
        ylabel (str, optional): Y-axis label, applied only when truthy.
    """
    # Unfilled step outline; value i spans [boundaries[i], boundaries[i+1]].
    ax.stairs(array2plot, boundaries, edgecolor=color, linewidth=1.5, label=label)

    # Log x-axis: precipitation rates span several orders of magnitude.
    ax.set_xscale('log')

    # Explicit ticks so every panel shows the same rate labels.
    ax.set_xticks(selected_labels)
    ax.set_xticklabels(selected_labels, rotation=28, ha="right")

    ax.set_xlabel(xlabel)
    # Restrict the view to the labelled range.
    ax.set_xlim(min(selected_labels), max(selected_labels))

    # Pull the (rotated) x tick labels closer to the axis.
    ax.tick_params(axis="x", pad=1)

    if ylabel:
        ax.set_ylabel(ylabel)



def plot_transition_panel(
    ax, model, Ts_low, Ts_high,
    dist_dic, df, boundaries,
    errors_dict=None,
    order = None,
    colors=None,
    selected_labels=None,
    xlabel="", ylabel=None,
    legend_pos="center left",
    show_legend=True,
    dist_type="amount",
):
    """
    Plot one panel comparing a low-T distribution, its shifted/scaled
    prediction, and the high-T distribution, annotated with b, d_m and a.

    The prediction shifts the low-T distribution by ``b`` in ln(rate)
    (column ``dellnqv_{key}``). It then rescales by ``exp(a)`` for amount
    distributions (``a`` = column ``dellnP_{key}``) or by ``exp(-d_m)`` for
    frequency distributions. ``d_m`` is read from the optional ``delta_{key}``
    column and otherwise defaults to ``b - a``, matching the other plotting
    helpers in this module.

    Parameters:
        ax (matplotlib.axes.Axes): The axis to plot on.
        model (str): Model name; matched against ``df["model"]``.
        Ts_low (str): Lower temperature case (e.g., "295").
        Ts_high (str): Higher temperature case (e.g., "300").
        dist_dic (dict): model -> temperature string -> sequence whose
            index 0 is the frequency and index 1 the amount distribution.
        df (pd.DataFrame): Must contain 'model', 'dellnqv_{key}' and
            'dellnP_{key}' columns, and optionally 'delta_{key}', where
            key = "{Ts_low}-{Ts_high}K".
        boundaries (np.ndarray): Histogram bin edges.
        errors_dict (dict, optional): model -> key -> error, printed as a
            percentage after multiplying by 100.
        order (str, optional): Panel label drawn in the top-right corner.
        colors (sequence, optional): [high-T, low-T, prediction] colors.
            Defaults to ("k", "grey", "coral").
        selected_labels (list): X tick positions/labels.
        xlabel (str): X-axis label.
        ylabel (str, optional): Y-axis label.
        legend_pos (str): Forwarded to ``panel_stair``, which ignores it; the
            legend itself is anchored at the upper left.
        show_legend (bool): Whether to show the legend.
        dist_type (str): "amount" selects index 1 of each distribution; any
            other value is treated as "frequency" (index 0).

    Missing data (unknown model, temperature key, or DataFrame column) is
    reported with a printed message instead of raising.
    """
    key = f"{Ts_low}-{Ts_high}K"
    if colors is None:
        colors = ("k", "grey", "coral")  # [high-T, low-T, prediction]

    try:
        # Determine which distribution array to use
        dist_index = 1 if dist_type == "amount" else 0

        data_high = dist_dic[model][f"{Ts_high}"][dist_index]
        data_low = dist_dic[model][f"{Ts_low}"][dist_index] #add K here for FV3, for RCEMIP, remove!

        # Scaling parameters; an absent row raises IndexError, an absent column KeyError.
        model_rows = df.loc[df["model"] == model]
        shift = model_rows[f"dellnqv_{key}"].values[0]
        increase = model_rows[f"dellnP_{key}"].values[0]
        try:
            d_m = model_rows[f"delta_{key}"].values[0]
        except KeyError:
            # Amount = rate x frequency, so a shift b with amount scaling a
            # implies a frequency scaling of d_m = b - a.
            d_m = shift - increase

        # Shift by b in ln(rate): the curve's value at r becomes its value at r*exp(-b).
        bin_centers = 0.5 * (boundaries[:-1] + boundaries[1:])
        r_shifted = bin_centers * np.exp(shift)
        interp_vals = np.interp(bin_centers, r_shifted, data_low, left=0, right=0)
        log_scale = increase if dist_type == "amount" else -d_m
        shifted_scaled = np.exp(log_scale) * interp_vals

        # Plot: low-T baseline, prediction, and high-T target (which carries the ylabel).
        panel_stair(ax, data_low, boundaries, colors[1], f"{Ts_low}K", selected_labels, xlabel, legend_pos)
        panel_stair(ax, shifted_scaled, boundaries, colors[2], rf"{Ts_low}K b & $d_m$", selected_labels, xlabel, legend_pos)
        panel_stair(ax, data_high, boundaries, colors[0], f"{Ts_high}K", selected_labels, xlabel, legend_pos, ylabel)

        # --- Annotations
        y_base = 0.46
        x_base = 0.08
        ax.text(x_base, y_base,  rf"$b$: {shift:.2f}", transform=ax.transAxes)
        ax.text(x_base, y_base - 0.09,  rf"$d_{{\mathrm{{m}}}}$: {d_m:.2f}", transform=ax.transAxes)
        ax.text(x_base, y_base - 0.18,  rf"$a$: {increase:.2f}", transform=ax.transAxes)
        ax.text(0.02, 0.98, model, transform=ax.transAxes, ha="left", va="top", color="gray", fontsize=9, zorder=10, clip_on=False)

        # Optional error
        if errors_dict and key in errors_dict.get(model, {}):
            error_val = errors_dict[model][key]
            err_y = 0.70 if dist_type == "amount" else 0.10
            ax.text(0.05, err_y, f"Error: {error_val*100:.0f}%", transform=ax.transAxes)

        if show_legend:
            ax.legend(
                loc="upper left",              # anchor point of legend box
                bbox_to_anchor=(-0.01, 0.95),  # (x, y) relative to axes [0-1]
                frameon=False
            )

        # Draw panel label (e.g. A, B, C)
        if order:
            ax.text(0.96, 0.98, order, transform=ax.transAxes,
                    ha='right', va='top', fontsize=9, fontweight='bold', color='gray')

    except KeyError as e:
        print(f"[{model} | {key}] Missing data: {e}")
    except IndexError:
        print(f"[{model} | {key}] No matching row in DataFrame.")




def plot_dual_row_transition_panels(
    dist_dic, df,
    boundaries, delLogR,
    selected_labels, colors,
    xlabel="mm/day", ylabel_freq=None, ylabel_amt=None,
    legend_pos="center left",
    models=None,
    transitions=None,
    fixed_model=None,
    fixed_transition=None,
    comparison_mode="model",  # "model" or "transition"
    which_rows="both",        # "frequency", "amount", or "both"
    errors_dict=None,
    orders = None,
    save=False,
    save_path=None
):
    """
    Create a panel plot showing frequency, amount, or both kinds of
    distribution for different models or for different transitions.

    Each panel is drawn by ``plot_transition_panel``.

    Parameters:
        dist_dic (dict): Nested dictionary with model -> Ts -> (frequency, amount).
        df (pd.DataFrame): DataFrame with shift/increase info.
        boundaries (np.ndarray): Histogram bin edges.
        delLogR (float): Bin width (unused; kept for compatibility).
        selected_labels (list): X-axis ticks.
        colors (list): [high-T, low-T, prediction] colors. Used only in "model"
            mode; "transition" mode derives colors from a viridis map keyed by
            temperature.
        xlabel (str): Shared x-axis label.
        ylabel_freq (str): Shared y-label used when only the frequency row is drawn.
        ylabel_amt (str): Shared y-label used whenever the amount row is drawn.
        legend_pos (str): Forwarded to ``plot_transition_panel``.
        models (list): Models to compare (required in "model" mode).
        transitions (list): (Ts_low, Ts_high) tuples (required in "transition" mode).
        fixed_model (str): The model used when comparing across transitions.
        fixed_transition (tuple): (Ts_low, Ts_high) used when comparing across models.
        comparison_mode (str): "model" (vary models) or "transition" (vary transitions).
        which_rows (str): "frequency", "amount", or "both".
        errors_dict (dict): Optional dictionary of error values.
        orders (list, optional): Panel labels, one per column.
        save (bool): Whether to save the figure.
        save_path (str): File path; required when ``save`` is True.

    Raises:
        ValueError: On an invalid ``which_rows``/``comparison_mode``, missing
            mode-specific arguments, or ``save=True`` without ``save_path``.
    """
    # --- Determine which rows to plot
    plot_frequency = which_rows in ("frequency", "both")
    plot_amount = which_rows in ("amount", "both")
    if not (plot_frequency or plot_amount):
        raise ValueError("which_rows must be 'frequency', 'amount', or 'both'.")
    n_rows = int(plot_frequency) + int(plot_amount)

    # Fail before any drawing rather than inside savefig(None).
    if save and not save_path:
        raise ValueError("save_path must be provided when save=True.")

    # --- Determine number of panels
    if comparison_mode == "model":
        if models is None or fixed_transition is None:
            raise ValueError("When comparison_mode='model', 'models' and 'fixed_transition' must be provided.")
        n_panels = len(models)
    elif comparison_mode == "transition":
        if fixed_model is None or transitions is None:
            raise ValueError("When comparison_mode='transition', 'fixed_model' and 'transitions' must be provided.")
        n_panels = len(transitions)
        # One color per temperature, so a given T looks the same in every panel.
        all_temps = sorted({float(T) for pair in transitions for T in pair})
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
        cmap = plt.get_cmap("viridis", len(all_temps))
        norm = mcolors.Normalize(vmin=min(all_temps), vmax=max(all_temps))
        temp_color_map = {T: cmap(norm(T)) for T in all_temps}
    else:
        raise ValueError("comparison_mode must be 'model' or 'transition'.")

    # --- Initialize figure
    fig = plt.figure(figsize=(6.7, 2.4))  # GRL-appropriate width
    spec = gridspec.GridSpec(n_rows, n_panels, height_ratios=[1] * n_rows)

    # --- Main panel loop
    for col in range(n_panels):
        if comparison_mode == "model":
            model = models[col]
            Ts_low, Ts_high = fixed_transition
            curve_colors = colors
        else:  # "transition"
            model = fixed_model
            Ts_low, Ts_high = transitions[col]

            high_color = temp_color_map[float(Ts_high)]
            low_color = temp_color_map[float(Ts_low)]

            # Prediction color: RGB blend leaning towards the low-T color.
            blend_weight = 0.4
            shifted_rgb = [
                (1 - blend_weight) * lo + blend_weight * hi
                for lo, hi in zip(low_color[:3], high_color[:3])
            ]
            curve_colors = [high_color, low_color, shifted_rgb]

        panel_order = orders[col] if orders else None

        # --- Top row: frequency
        if plot_frequency:
            ax_top = fig.add_subplot(spec[0, col])
            if col > 0:
                ax_top.tick_params(labelleft=False)

            # Fixed y-range per panel; SAM gets a tighter one.
            ax_top.set_ylim(bottom=0, top=0.048 if model == "SAM" else 0.10)

            plot_transition_panel(
                ax_top, model, Ts_low, Ts_high,
                dist_dic, df, boundaries,
                errors_dict,
                panel_order,
                curve_colors, selected_labels,
                xlabel="",
                ylabel="" if col == 0 else None,
                legend_pos=legend_pos,
                show_legend=True,
                dist_type="frequency"
            )

        # --- Bottom row: amount
        if plot_amount:
            row_index = 1 if plot_frequency else 0
            ax_bot = fig.add_subplot(spec[row_index, col])

            # Remove tick labels on middle and right panels
            if col > 0:
                ax_bot.tick_params(labelleft=False)

            plot_transition_panel(
                ax_bot, model, Ts_low, Ts_high,
                dist_dic, df, boundaries,
                errors_dict,
                panel_order,
                curve_colors, selected_labels,
                xlabel=None,
                ylabel=None,
                legend_pos=legend_pos,
                show_legend=True,
                dist_type="amount"
            )

    # --- Shared labels and layout
    fig.text(0.5, 1/20, xlabel, ha="center", fontsize=9)
    row_ylabel = ylabel_amt if plot_amount else ylabel_freq
    fig.text(0, 0.50, row_ylabel, va="center", rotation="vertical", fontsize=9)
    fig.suptitle(r"FV3: Precip AmDists: shifted (b) & decreased ($d_m$).", fontsize=10, y = 1.01)  # change plot title here
    fig.subplots_adjust(left=0.06, right=0.995, top=0.93, bottom=0.18, wspace=0.05, hspace=0.06)

    if save:
        fig.savefig(
            save_path,
            dpi=600,            # high-res if rasterized
            bbox_inches="tight",
            pad_inches=0.02
        )

    plt.show()


def plot_2x4_transition_panels(
    dist_dic, dist_type, df,
    boundaries, delLogR,
    selected_labels, colors,
    xlabel="mm/day", ylabel_freq=None, ylabel_amt=None,
    legend_pos="center left",
    models=None,
    fixed_transition=None,
    comparison_mode="model",
    errors_dict=None,
    inset=False,
    save=False,
    save_path=None
):
    """
    Create a 2x4 panel plot comparing precipitation distributions across models
    for a fixed temperature transition (e.g., 295K -> 305K).

    Every model panel shows the high-T target (black), the predicted
    distribution (coral) and the raw low-T baseline (dotted grey). The
    prediction shifts the low-T curve by ``b`` (``dellnqv_{key}``) in ln(rate).
    It then scales by ``exp(a)`` (``a`` = ``dellnP_{key}``) for amounts or by
    ``exp(-d_m)`` for frequencies. ``d_m`` is taken from ``delta_{key}`` if
    that column exists, else ``b - a``.

    Behavior depends on `dist_type`:
    -------------------------------------------------
    - "amount":
        * Panels are annotated with `b`, `d` (= d_m) and, when available, the
          relative `Error`.
        * Panel [1, 3] is turned off and holds a single frameless legend.
        * All panels share the y-axis.

    - "frequency":
        * No annotations and no legend.
        * Each panel gets its own y-limit (1.05 x the largest value of the
          three curves inside the x-range of `selected_labels`) and
          scientific-notation y tick labels.
        * Panel [1, 3] is left empty and turned off.

    Parameters
    ----------
    dist_dic : dict
        model_name -> temperature -> (frequency_array, amount_array).
    dist_type : str
        "amount" or "frequency".
    df : pd.DataFrame
        DataFrame containing `dellnqv_*`, `dellnP_*`, and optional `delta_*`.
    boundaries : np.ndarray
        Histogram bin edges.
    delLogR : float
        Bin width (unused but kept for compatibility).
    selected_labels : list
        X-axis ticks and tick labels (e.g., [0.1, 1, 10, 100, 1000]).
    colors : list
        [highT, lowT, shifted]. Only used for the optional inset
        (colors[0] = target, colors[2] = prediction); the main curves use
        fixed colors.
    xlabel, ylabel_freq, ylabel_amt : str
        Shared axis labels.
    legend_pos : str
        Legend location inside panel [1, 3] (used only in "amount" mode).
    models : list
        Model names. "MESONH" is always skipped; at most 7 other models fit,
        because panel [1, 3] is reserved. Unused model slots are turned off.
    fixed_transition : tuple
        (Ts_low, Ts_high), e.g., ("295", "305").
    comparison_mode : str
        Must be "model".
    errors_dict : dict
        Nested dict: errors_dict[model][transition_key] = float error.
    inset : bool
        Whether to add a quantile inset via ``add_quantile_inset``.
    save : bool
        Whether to save figure.
    save_path : str
        File path if save=True (saving is skipped without a path).

    Raises
    ------
    ValueError
        On unsupported `comparison_mode`/`dist_type`, missing arguments, or
        more than 7 plottable models.
    """
    from matplotlib.ticker import ScalarFormatter

    # Explicit checks instead of `assert`, so validation survives `python -O`.
    if comparison_mode != "model":
        raise ValueError("Only 'model' comparison supported.")
    if fixed_transition is None:
        raise ValueError("fixed_transition must be provided.")
    if models is None:
        raise ValueError("models must be provided.")

    # MESONH is never plotted. Slot [1, 3] is reserved for the legend, so an
    # 8th model would be drawn and then hidden -- reject that instead.
    models_to_plot = [model for model in models if model != "MESONH"]
    if len(models_to_plot) > 7:
        raise ValueError(
            "At most 7 models (excluding MESONH) fit the 2x4 layout; "
            f"got {len(models_to_plot)}."
        )

    # --- Choose distribution type
    if dist_type == "amount":
        index, sharey_flag = 1, True
    elif dist_type == "frequency":
        index, sharey_flag = 0, False
    else:
        raise ValueError("dist_type must be 'amount' or 'frequency'")

    Ts_low, Ts_high = fixed_transition
    key = f"{Ts_low}-{Ts_high}K"

    # Short display names for long model identifiers.
    display_names = {"UKMOi-vn11.0-CASIM": "UKMO", "WRF_COL_CRM": "WRF"}

    # Loop-invariant geometry.
    bin_centers = 0.5 * (boundaries[:-1] + boundaries[1:])
    in_view = (bin_centers >= min(selected_labels)) & (bin_centers <= max(selected_labels))

    # --- Setup figure
    fig, axs = plt.subplots(2, 4, figsize=(6.7, 3), sharey=sharey_flag)

    for idx, model in enumerate(models_to_plot):
        row, col = divmod(idx, 4)
        ax = axs[row][col]

        # Distributions for this model and transition
        low = dist_dic[model][f"{Ts_low}"][index]    # Low temperature baseline
        high = dist_dic[model][f"{Ts_high}"][index]  # High temperature target

        # Scaling parameters
        model_rows = df.loc[df["model"] == model]
        b = model_rows[f"dellnqv_{key}"].values[0]  # Shift parameter
        a = model_rows[f"dellnP_{key}"].values[0]   # Amount scaling
        try:
            d_m = model_rows[f"delta_{key}"].values[0]  # Frequency scaling
        except KeyError:
            d_m = b - a  # Fallback calculation

        # Shift the low-T curve by b in ln(rate), then scale per distribution type.
        low_shifted = np.interp(bin_centers, bin_centers * np.exp(b), low, left=0, right=0)
        pred = np.exp(a if dist_type == "amount" else -d_m) * low_shifted

        if inset:
            # NOTE(review): add_quantile_inset is not defined in the visible
            # part of this module; presumably defined elsewhere -- confirm.
            add_quantile_inset(
                ax=ax,
                bin_centers=bin_centers,
                pred=pred,
                high=high,
                boundaries=boundaries,
                quantile=0.85,
                mode=dist_type,
                pred_color=colors[2],
                high_color=colors[0],
                inset_size=0.35,
                avoid_locs=("lower left",),
                connectors=True
            )

        # --- Curves: target (high), prediction, and baseline (low)
        ax.stairs(high, boundaries, edgecolor='k', linewidth=1.0, label=r'Target-$p_p$')
        ax.stairs(pred, boundaries, edgecolor="coral", linewidth=1.0, label=r'Predicted-$p_{sd}$')
        ax.stairs(low, boundaries, edgecolor='grey', linewidth=0.7, linestyle=':',
                  label=rf"{Ts_low} K")

        # --- Axis configuration
        ax.set_xscale("log")
        ax.set_xticks(selected_labels)
        ax.set_xlim(min(selected_labels), max(selected_labels))

        # Only the bottom row carries x tick labels
        if row == 1:
            ax.set_xticklabels(selected_labels, rotation=45, ha="right")
        else:
            ax.set_xticklabels([])

        display_name = display_names.get(model, model)

        if dist_type == "amount":
            ax.text(0.05, 0.97, display_name, transform=ax.transAxes,
                    ha="left", va="top", color="gray", fontsize=9)

            # Parameter annotations (amount mode only)
            ax.text(0.01, 0.61, rf"$b$: {b:.2f}", transform=ax.transAxes, fontsize=8)
            ax.text(0.01, 0.49, rf"$d$: {d_m:.2f}", transform=ax.transAxes, fontsize=8)
            if errors_dict and model in errors_dict and key in errors_dict[model]:
                err_val = errors_dict[model][key]
                ax.text(0.01, 0.37, f"Error: {err_val*100:.0f}%", transform=ax.transAxes, fontsize=8)
        else:
            # WRF's name sits slightly lower than the others in frequency mode.
            name_y = 0.05 if model == "WRF_COL_CRM" else 0.08
            ax.text(0.05, name_y, display_name, transform=ax.transAxes,
                    ha="left", va="bottom", fontsize=9, fontweight="bold", color="gray")

            # Per-panel y-limit from the visible part of all three curves;
            # skipped when no bin center falls inside the x-range.
            if in_view.any():
                ymax = 1.05 * max(np.max(high[in_view]), np.max(pred[in_view]), np.max(low[in_view]))
                ax.set_ylim(0, ymax)

            ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
            ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    # Turn off model slots left unused when fewer than 7 models are given.
    for idx in range(len(models_to_plot), 7):
        row, col = divmod(idx, 4)
        axs[row][col].set_axis_off()

    # --- Panel [1, 3]: legend holder (amount) or empty (frequency)
    legend_ax = axs[1][3]
    legend_ax.set_axis_off()
    if dist_type == "amount":
        handles, labels = axs[0][0].get_legend_handles_labels()
        legend_ax.legend(handles, labels, loc=legend_pos, frameon=False)

    # --- Shared axis labels
    fig.text(0.5, -1/25, xlabel, ha="center")
    if dist_type == "amount":
        fig.text(0.01, 0.5, ylabel_amt, va="center", rotation="vertical")
    else:
        fig.text(0, 0.5, ylabel_freq, va="center", rotation="vertical")

    # --- Figure title
    fig.suptitle(
        f"RCEMIP: Precipitation {dist_type.capitalize()} Distributions. "
        f"{Ts_low}-{Ts_high}K",
        y=1, fontsize=10
    )

    fig.subplots_adjust(left=0.08, right=0.98, top=0.90, bottom=0.10,
                        wspace=0.12, hspace=0.1)

    if save and save_path:
        fig.savefig(save_path, dpi=600, bbox_inches="tight", pad_inches=0.02)

    plt.show()

def plot_1x3_freq_amt(
    dist_dic, df, boundaries,
    fixed_model, transitions,
    selected_labels,
    dist_type="freq",              # "freq" or "amt"
    plot_difference=False,
    title="",
    xlabel=r"mm/day [rr]",
    ylabel_freq=r"Prob. / $\Delta \log(\mathrm{rr})$",
    ylabel_amt=r"Prob. / $\Delta \log(\mathrm{rr})\cdot rr$",
    ylabel_diff=r"Target - Predicted",
    legend_pos="upper left",
    save=False,
    save_path=None,
    errors_dict=None
):
    """
    Plot a 1x3 panel of frequency or amount precipitation distributions
    across temperature transitions for a single model. Optionally adds a
    secondary y-axis showing the difference (target - predicted).

    The prediction shifts the low-T curve by ``b`` (``dellnqv_{key}``) in
    ln(rate). It then scales by ``exp(a)`` (``a`` = ``dellnP_{key}``) for
    amounts or by ``exp(-d_m)`` for frequencies. ``d_m`` is taken from
    ``delta_{key}`` if present, else ``b - a``.

    Parameters:
        dist_dic (dict): Dictionary of model output distributions.
                         Format: model -> temp (e.g., '280K') -> (freq, amt)
        df (DataFrame): DataFrame containing b, a, and optionally d_m shift parameters
        boundaries (ndarray): Bin edges for distributions
        fixed_model (str): Model name to plot
        transitions (list): Temperature pairs [(low, high), ...]; at most 3
        selected_labels (list): X-ticks to use for log-scale x-axis
        dist_type (str): "freq" (alias "frequency") or "amt" (alias "amount")
        plot_difference (bool): If True, adds secondary y-axis showing prediction error
        title (str): Plot title
        xlabel (str): X-axis label
        ylabel_freq (str): Y-axis label for frequency
        ylabel_amt (str): Y-axis label for amount
        ylabel_diff (str): Y-axis label for secondary axis (difference)
        legend_pos (str): Position for legend (first panel only)
        save (bool): If True, saves the figure
        save_path (str): Path to save the figure (saving is skipped without it)
        errors_dict (dict): Optional. model -> "Ts_low-Ts_highK" -> float (relative L1 error)

    Raises:
        ValueError: If ``dist_type`` is not one of the accepted spellings.
    """
    # Normalize dist_type once. Previously any value other than "freq" selected
    # amount data while only exactly "amt" selected the amount scaling, so
    # e.g. "amount" silently mixed the two modes.
    aliases = {"freq": "freq", "frequency": "freq", "amt": "amt", "amount": "amt"}
    if dist_type not in aliases:
        raise ValueError("dist_type must be 'freq' or 'amt'.")
    dist_type = aliases[dist_type]

    panel_name = ['A', 'B', 'C']

    bin_centers = 0.5 * (boundaries[:-1] + boundaries[1:])
    fig, axs = plt.subplots(1, 3, figsize=(6.7, 2.4), sharey=True)

    def _get(df, model, key):
        """Return (b, a) = (dellnqv, dellnP) for ``model`` and transition ``key``."""
        return (df.loc[df["model"] == model, f"dellnqv_{key}"].values[0],
                df.loc[df["model"] == model, f"dellnP_{key}"].values[0])

    for col, (Ts_low, Ts_high) in enumerate(transitions):
        key = f"{Ts_low}-{Ts_high}K"
        ax = axs[col]
        ax_diff = ax.twinx() if plot_difference else None

        # Extract model distributions and pick the requested kind
        freq_low, amt_low = dist_dic[fixed_model][f"{Ts_low}K"]
        freq_high, amt_high = dist_dic[fixed_model][f"{Ts_high}K"]
        if dist_type == "freq":
            low, high = freq_low, freq_high
        else:
            low, high = amt_low, amt_high

        # Shift parameters
        b, a = _get(df, fixed_model, key)
        try:
            d_m = df.loc[df["model"] == fixed_model, f"delta_{key}"].values[0]
        except KeyError:
            d_m = b - a

        r_shift = bin_centers * np.exp(b)
        low_interp = np.interp(bin_centers, r_shift, low, left=0, right=0)
        predicted = np.exp(a if dist_type == "amt" else -d_m) * low_interp

        # The 280 K baseline is drawn only for the 280 K transition; float()
        # makes "280", 280 and 280.0 all match.
        if float(Ts_low) == 280:
            ax.stairs(low, boundaries, edgecolor='grey', linewidth=0.7, linestyle=":", label=f"{Ts_low} K")

        ax.stairs(high, boundaries, edgecolor='k', linewidth=1, label=r'Target-$f_p$')
        ax.stairs(predicted, boundaries, edgecolor="darkseagreen", linewidth=1, label=r'Prediction-$f_{sd}$')

        # Plot difference if requested
        if plot_difference:
            diff = high - predicted
            ax_diff.stairs(diff, boundaries, edgecolor="black", linewidth=0.9, linestyle="--")

            # Align y=0 between primary and secondary axes using mpl_axes_aligner.
            # NOTE(review): the primary ylim is set later (freq mode) and is shared
            # across panels, which can undo this alignment -- confirm the visual.
            mpl_axes_aligner.align.yaxes(ax, 0, ax_diff, 0, 0.5)

            # Difference ticks/label only on the third panel
            if col == 2:
                ax_diff.set_ylabel(ylabel_diff)
            else:
                ax_diff.set_yticks([])
                ax_diff.set_ylabel("")

        # Axes formatting
        ax.set_xscale("log")
        ax.set_xlim(min(selected_labels), max(selected_labels))
        if dist_type == "freq":
            ax.set_ylim(0, 0.10)
        ax.set_xticks(selected_labels)
        ax.set_xticklabels(selected_labels, rotation=45, ha="right")
        ax.set_title(f"{Ts_low}-{Ts_high}K")

        if col == 0:
            ax.set_ylabel(ylabel_freq if dist_type == "freq" else ylabel_amt)
        else:
            ax.tick_params(labelleft=False)

        # Annotations
        ax.text(0.65, 0.42, rf"$b$: {b:.2f}", transform=ax.transAxes)
        ax.text(0.65, 0.30, rf"$d$: {d_m:.2f}", transform=ax.transAxes)
        ax.text(0.90, 0.97, panel_name[col], transform=ax.transAxes,
                ha="left", va="top", color="gray", fontsize=9)
        if errors_dict:
            try:
                err = errors_dict[fixed_model][key]
                ax.text(0.03, 0.15, f"Error: {err*100:.0f}%", transform=ax.transAxes)
            except KeyError:
                pass  # no error value for this model/transition

        if col == 0:
            ax.legend(loc=legend_pos, frameon=False)

    fig.text(0.5, 0.02, xlabel, ha="center")
    if title:
        fig.suptitle(title, y=1.02)

    fig.subplots_adjust(left=0.09, right=0.98, top=0.82, bottom=0.2, wspace=0.05)

    if save and save_path:
        fig.savefig(save_path, dpi=600, bbox_inches="tight", pad_inches=0.02)

    plt.show()


def plot_2x3_freq_amt(
    dist_dic, df, boundaries,
    fixed_model, transitions,
    selected_labels,
    title="",
    xlabel=r"mm/day [rr]",
    ylabel_freq=r"Prob. / $\Delta \log(\mathrm{rr})$",
    ylabel_amt=r"Prob. / $\Delta \log(\mathrm{rr})\cdot rr$",
    legend_pos="upper left",
    save=False, save_path=None,
    errors_dict=None
):
    """
    Plot a 2x3 figure of precipitation distributions for one model across temperature transitions.

    Top row: frequency distributions; bottom row: amount distributions; one column per
    transition. Each panel shows the high-T target (orangered) and the prediction built
    from the low-T distribution (mediumpurple). When the low temperature is 280K, the
    280K distribution itself is also drawn in grey as a reference.

    The prediction follows:
        low_shifted(r) = interp of the low-T curve with its bin centres stretched by exp(b)
        frequency: pred = exp(-d_m) * low_shifted
        amount:    pred = exp(a)    * low_shifted
    with b = dellnqv_{key}, a = dellnP_{key}, d_m = delta_{key} (or b - a if that
    column is absent), key = "{Ts_low}-{Ts_high}K".

    Parameters:
        dist_dic (dict): fixed_model -> "{T}K" -> (frequency_array, amount_array).
        df (pd.DataFrame): one row per model with 'model', 'dellnqv_{key}',
            'dellnP_{key}' and optionally 'delta_{key}' columns.
        boundaries (np.ndarray): histogram bin edges (one longer than each distribution).
        fixed_model (str): model to plot.
        transitions (list[tuple]): (Ts_low, Ts_high) pairs, numbers or numeric strings;
            the 2x3 grid has room for three of them.
        selected_labels (list): x tick positions/labels; also set the x-limits.
        title, xlabel, ylabel_freq, ylabel_amt (str): figure/axis labels.
        legend_pos (str): legend location for the amount panels.
        save (bool), save_path (str): save the figure at 600 dpi when both are set.
        errors_dict (dict): Optional. Nested dict of relative L1 errors:
            model -> "Ts_low-Ts_highK" -> float; shown on the amount panels.
    """

    grey = "0.6"  # Colour of the 280K reference curve

    def _row_value(column):
        # Raises KeyError for a missing column and IndexError for a missing model row.
        return df.loc[df["model"] == fixed_model, column].values[0]

    bin_centers = 0.5 * (boundaries[:-1] + boundaries[1:])
    fig, axs = plt.subplots(2, 3, figsize=(6.7, 3.8), sharex=False, sharey=False)

    # Share y within each row so the three transitions are directly comparable.
    for j in (1, 2):
        axs[0, j].sharey(axs[0, 0])
        axs[1, j].sharey(axs[1, 0])

    fig.subplots_adjust(left=0.09, right=0.995, top=0.88, bottom=0.22,
                        wspace=0.07, hspace=0.15)

    for col, (Ts_low, Ts_high) in enumerate(transitions):
        key = f"{Ts_low}-{Ts_high}K"
        axF = axs[0, col]
        axA = axs[1, col]

        # float() so that both 280 and "280" are recognised as the reference temperature.
        is_280_reference = float(Ts_low) == 280

        # Extract distributions
        freq_low, amt_low = dist_dic[fixed_model][f"{Ts_low}K"][0], dist_dic[fixed_model][f"{Ts_low}K"][1]
        freq_high, amt_high = dist_dic[fixed_model][f"{Ts_high}K"][0], dist_dic[fixed_model][f"{Ts_high}K"][1]

        # Shift (b) and scaling (a, d_m) parameters
        b = _row_value(f"dellnqv_{key}")
        a = _row_value(f"dellnP_{key}")
        try:
            d_m = _row_value(f"delta_{key}")
        except KeyError:
            d_m = b - a

        r_shift = bin_centers * np.exp(b)
        freq_pred = np.exp(-d_m) * np.interp(bin_centers, r_shift, freq_low, left=0, right=0)
        amt_pred = np.exp(a) * np.interp(bin_centers, r_shift, amt_low, left=0, right=0)

        # --- Frequency Panel ---
        if is_280_reference:
            axF.stairs(freq_low, boundaries, edgecolor=grey, linewidth=0.7, label=f"{Ts_low}K")
        axF.stairs(freq_high, boundaries, edgecolor='orangered', linewidth=0.8, label=f"{Ts_high}K")
        axF.stairs(freq_pred, boundaries, edgecolor='mediumpurple', linewidth=1.0, label=rf"{Ts_low}K b & $d_m$")

        axF.set_xscale("log")
        axF.set_xticks(selected_labels)
        axF.set_xticklabels([])
        axF.set_xlim(min(selected_labels), max(selected_labels))
        axF.set_ylim(0, 0.10)

        axF.text(0.60, 0.62, rf"$b$: {b:.2f}", transform=axF.transAxes)
        axF.text(0.60, 0.50, rf"$d_m$: {d_m:.2f}", transform=axF.transAxes)

        if col == 0:
            axF.set_ylabel(ylabel_freq)
        else:
            axF.tick_params(labelleft=False)

        axF.text(0.96, 0.97, chr(ord('A') + col), transform=axF.transAxes,
                 ha="right", va="top", fontsize=9, fontweight="bold", color="gray")

        # --- Amount Panel ---
        if is_280_reference:
            axA.stairs(amt_low, boundaries, edgecolor=grey, linewidth=0.7, label=f"{Ts_low}K")
        axA.stairs(amt_high, boundaries, edgecolor='orangered', linewidth=.8, label=f"{Ts_high}K")
        axA.stairs(amt_pred, boundaries, edgecolor='mediumpurple', linewidth=1, label=rf"{Ts_high}K Predicted")

        axA.set_xscale("log")
        axA.set_xticks(selected_labels)
        axA.set_xticklabels(selected_labels, rotation=45, ha="right")
        axA.set_xlim(min(selected_labels), max(selected_labels))

        if col == 0:
            axA.set_ylabel(ylabel_amt)
        else:
            axA.tick_params(labelleft=False)

        axA.text(0.96, 0.97, chr(ord('D') + col), transform=axA.transAxes,
                 ha="right", va="top", fontsize=9, fontweight="bold", color="gray")

        # Annotate error if available (missing entries are simply not shown)
        err = (errors_dict or {}).get(fixed_model, {}).get(key)
        if err is not None:
            axA.text(0.08, 0.35, f"Error: {err*100:.0f}%", transform=axA.transAxes)

        axA.legend(loc=legend_pos, bbox_to_anchor=(-1/30, 1.0), frameon=False)

    fig.text(0.5, 0.12, xlabel, ha="center")
    if title:
        fig.suptitle(title, y=0.95)

    if save and save_path:
        plt.savefig(save_path, dpi=600, bbox_inches="tight", pad_inches=0.02)

    plt.show()



def compute_distribution_errors(dist_dic, df, boundaries, models, transitions, dist_type="amount",
                                temp_suffix="", verbose=False):
    """
    Compute relative L1 error between actual (high-temperature) and transformed (shifted+scaled low-temperature)
    precipitation distributions for multiple models and temperature transitions.

    The transformed distribution matches the prediction drawn by the plotting functions:
        low_shifted = np.interp(bin_centers, bin_centers * exp(b), low, left=0, right=0)
        amount:    pred = exp(a)    * low_shifted
        frequency: pred = exp(-d_m) * low_shifted
    with b = dellnqv_{key}, a = dellnP_{key}, d_m = delta_{key} (falling back to b - a
    when the delta column is absent), and key = "{Ts_low}-{Ts_high}K".

    Parameters:
        dist_dic (dict): Nested dictionary of distributions structured as:
            model_name -> temperature_key -> (frequency_array, amount_array)

        df (pd.DataFrame): DataFrame containing shift and increase values:
            - Columns like 'dellnP_295-305K', 'dellnqv_295-305K', optionally 'delta_295-305K'

        boundaries (np.ndarray): Array of bin edges used for histograms

        models (list[str]): List of model names to evaluate

        transitions (list[tuple[str, str]]): List of (Ts_low, Ts_high) temperature string pairs, e.g., ('295', '305')

        dist_type (str): Either "amount" or "frequency"; determines which distribution index to use

        temp_suffix (str): Appended to Ts_low/Ts_high to form the dist_dic keys
            (e.g. "K" for keys like "295K"). With the default "", Ts_low/Ts_high are used verbatim.

        verbose (bool): If True, print each transition key as it is processed.

    Returns:
        dict: Nested dictionary of relative L1 errors:
            model_name -> "Ts_low-Ts_highK" -> float (L1 error; NaN when the target has zero total mass)
            Transitions with missing distributions, columns, or DataFrame rows are reported
            on stdout and omitted.
    """

    index = 1 if dist_type == "amount" else 0

    # Compute bin centers from bin edges
    bin_centers = 0.5 * (boundaries[:-1] + boundaries[1:])

    def _dist_key(T):
        # Keys are used verbatim unless a suffix is requested.
        return f"{T}{temp_suffix}" if temp_suffix else T

    def _value(model, column):
        # KeyError -> missing column; IndexError -> no row for this model.
        return df.loc[df["model"] == model, column].values[0]

    errors_dict = {}  # Final nested output

    for model in models:
        model_errors = {}  # Per-model container

        for Ts_low, Ts_high in transitions:
            key = f"{Ts_low}-{Ts_high}K"
            if verbose:
                print(key)

            try:
                data_low = np.asarray(dist_dic[model][_dist_key(Ts_low)][index], dtype=float)
                data_high = np.asarray(dist_dic[model][_dist_key(Ts_high)][index], dtype=float)

                shift = _value(model, f"dellnqv_{key}")
                if dist_type == "amount":
                    log_scale = _value(model, f"dellnP_{key}")
                else:
                    # Same convention as the plots: frequency scales by exp(-d_m).
                    delta_col = f"delta_{key}"
                    if delta_col in df.columns:
                        d_m = _value(model, delta_col)
                    else:
                        d_m = shift - _value(model, f"dellnP_{key}")
                    log_scale = -d_m
            except KeyError as e:
                print(f"[Missing key] {model} | {key}: {e}")
                continue
            except IndexError:
                print(f"[Missing row] {model} | {key}: No matching row in DataFrame")
                continue

            # Shift x-axis by exp(shift) and resample the low-T curve onto the original centres
            r_shifted = bin_centers * np.exp(shift)
            interp_vals = np.interp(bin_centers, r_shifted, data_low, left=0, right=0)
            predicted = np.exp(log_scale) * interp_vals

            # Relative L1 error: total deviation relative to the target's total mass
            total = np.sum(data_high)
            if total == 0:
                model_errors[key] = np.nan
            else:
                model_errors[key] = np.sum(np.abs(predicted - data_high)) / total

        errors_dict[model] = model_errors

    return errors_dict

from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
from matplotlib.ticker import ScalarFormatter
import matplotlib.patches as mpatches


def add_quantile_inset(
    ax,
    bin_centers,
    pred,
    high,
    boundaries,
    quantile=0.90,
    mode="frequency",                 # "frequency" | "amount" (just for heuristics)
    pred_color="purple",
    high_color="crimson",
    inset_size=0.35,                  # fraction of parent axis
    inset_loc_candidates=("center left", "upper right", "lower left", "lower right"),
    avoid_locs=("lower left",),       # corners to down-weight (e.g., where a legend might live)
    connectors=False,                 # draw connectors between main and inset (optional)
    rect_alpha=0.12,                  # alpha for highlight rectangle on parent
):
    """
    Attach an inset to `ax` that magnifies the upper tail of the x-range.

    The tail starts at the first bin whose centre is >= the `quantile` of
    `bin_centers` (a quantile of the x positions, not of the distribution mass).
    The inset is placed in the candidate corner whose half of the panel has the
    smallest peak value of `pred`/`high`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Host axes for the inset.
    bin_centers : (N,) ndarray
        Bin centres matching `boundaries`.
    pred, high : (N,) ndarray
        Prediction and target values, one per bin.
    boundaries : (N+1,) ndarray
        Bin edges passed to `stairs`.
    quantile : float in (0,1)
        Quantile of `bin_centers` at which the zoomed range begins.
    mode : {"frequency", "amount"}
        "amount" activates the `avoid_locs` penalty; "frequency" switches the
        inset's y-axis to scientific notation.
    pred_color, high_color : color
        Line colours used inside the inset.
    inset_size : float
        Inset width and height as a fraction of the host axes.
    inset_loc_candidates : tuple[str]
        Locations to choose from. Only the four corner names are scored; any
        other entry (such as the default "center left") receives cost 1e9 and
        therefore ranks behind every scored corner.
    avoid_locs : tuple[str]
        Locations whose cost is multiplied by 1.15 when mode == "amount".
    connectors : bool
        If True, shade the zoomed region on the host and draw `mark_inset` connectors.
    rect_alpha : float
        Opacity of the shaded region.

    Returns
    -------
    matplotlib.axes.Axes or None
        The inset axes, or None when the sliced tail has nothing to draw.
    """

    n_bins = len(bin_centers)

    # Tail start index, clamped into the valid bin range.
    tail_threshold = np.quantile(bin_centers, quantile)
    start = int(np.searchsorted(bin_centers, tail_threshold, side="left"))
    start = min(max(start, 0), n_bins - 1)

    # Slicing values and edges from the same index keeps len(edges) == len(values) + 1.
    tail_pred = pred[start:]
    tail_high = high[start:]
    tail_edges = boundaries[start:]
    has_content = tail_pred.size != 0 and tail_high.size != 0 and tail_edges.size >= 2
    if not has_content:
        return None

    tail_left = tail_edges[0]
    tail_right = tail_edges[-1]

    # Split the host's visible x-range in half (geometric midpoint on log axes).
    lo, hi = ax.get_xlim()
    on_log_axis = lo > 0 and hi > 0 and ax.get_xscale() == "log"
    if on_log_axis:
        split = np.sqrt(lo * hi)
    else:
        split = 0.5 * (lo + hi)

    def _peak(mask):
        # Highest value of either curve inside the masked bins (0.0 if none).
        if not mask.any():
            return 0.0
        return float(max(np.max(pred[mask]), np.max(high[mask])))

    left_peak = _peak(bin_centers <= split)
    right_peak = _peak(bin_centers > split)

    peak_by_corner = {
        "upper left": left_peak,
        "lower left": left_peak,
        "upper right": right_peak,
        "lower right": right_peak,
    }

    def _placement_cost(loc):
        penalty = 1.15 if (loc in avoid_locs and mode == "amount") else 1.0
        return peak_by_corner.get(loc, 1e9) * penalty

    # Stable sort: ties keep the order of inset_loc_candidates.
    corner = sorted(inset_loc_candidates, key=_placement_cost)[0]

    size_spec = f"{int(inset_size*100)}%"
    inset_ax = inset_axes(
        ax,
        width=size_spec,
        height=size_spec,
        loc=corner,
        borderpad=0.4,
    )

    for values, color in ((tail_pred, pred_color), (tail_high, high_color)):
        inset_ax.stairs(values, tail_edges, edgecolor=color, linewidth=1.2)

    inset_ax.set_xscale(ax.get_xscale())
    inset_ax.set_xlim(tail_left, tail_right)

    # y-range from the zoomed data only, with 5% headroom; 1.0 if degenerate.
    top = 1.05 * max(float(np.max(tail_pred)), float(np.max(tail_high)))
    if not (np.isfinite(top) and top > 0):
        top = 1.0
    inset_ax.set_ylim(0.0, top)

    inset_ax.tick_params(axis="both", which="both", labelsize=7, length=2)
    inset_ax.set_xlabel("")
    inset_ax.set_ylabel("")
    if mode == "frequency":
        inset_ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
        inset_ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    if not connectors:
        return inset_ax

    # Shade the zoomed x-range on the host, from y=0 up to its current top.
    _, host_top = ax.get_ylim()
    highlight = mpatches.Rectangle(
        (tail_left, 0.0),
        width=(tail_right - tail_left),
        height=(host_top - 0.0),
        linewidth=0.0,
        facecolor="grey",
        alpha=rect_alpha
    )
    highlight.set_transform(ax.transData)
    ax.add_patch(highlight)

    # mark_inset corner codes per inset location; unknown locations use (1, 3).
    corner_codes = {
        "upper left": (2, 4),
        "upper right": (1, 3),
        "lower left": (3, 1),
        "lower right": (4, 2),
    }
    loc1, loc2 = corner_codes.get(corner, (1, 3))
    mark_inset(parent_axes=ax, inset_axes=inset_ax, loc1=loc1, loc2=loc2, fc="none", ec="0.5", lw=0.8)

    return inset_ax


def plot_4x4_transition_panels_with_insets(
    dist_dic, dist_type, df,
    boundaries, delLogR,
    selected_labels, colors,
    xlabel="mm/day", ylabel_freq=None, ylabel_amt=None,
    legend_pos="center left",
    models=None,
    fixed_transition=None,
    comparison_mode="model",
    errors_dict=None,
    zoom_percentile=0.99,
    save=False,
    save_path=None
):
    """
    Create a 4x4 panel plot comparing precipitation distributions across models
    for a fixed temperature transition (e.g., 295K → 305K), with dedicated panels
    for zoomed views of extreme precipitation.

    Seven models are drawn: "MESONH" is always removed from `models`.

    Layout:
    - Row 0: zoom panels for the models in row 1
    - Row 1: main distributions for models 1-4
    - Row 2: main distributions for models 5-7; panel [2,3] is the zoom panel of model 7
    - Row 3: zoom panels for models 5 and 6, legend in [3,2], [3,3] left empty

    Visual features:
    - Dashed rectangles in main panels mark the zoom region (grounded at y=0)
    - One connector line per rectangle/zoom pair (two for the inline [2,3] case)
    - Only main panel [2,2] shows x tick labels; zoom panels show zoom-domain labels
    - Zoom panels use the same x/y ranges as their rectangle

    Prediction (b = dellnqv_{key}, a = dellnP_{key}, d_m = delta_{key} or b - a):
        low_shifted = np.interp(bin_centers, bin_centers * exp(b), low, left=0, right=0)
        amount:    pred = exp(a)    * low_shifted
        frequency: pred = exp(-d_m) * low_shifted

    Parameters
    ----------
    dist_dic : dict
        model_name -> temperature -> (frequency_array, amount_array); temperatures
        are looked up as f"{Ts_low}" / f"{Ts_high}".
    dist_type : str
        "amount" or "frequency".
    df : pd.DataFrame
        DataFrame containing `model`, `dellnqv_*`, `dellnP_*`, and optional `delta_*`.
    boundaries : np.ndarray
        Histogram bin edges.
    delLogR : float
        Bin width (unused but kept for compatibility).
    selected_labels : list
        X-axis tick positions/labels (e.g., [0.1, 1, 10, 100, 1000]); their max is
        also the right edge of every zoom domain.
    colors : list
        Unused (line colours are fixed); kept for compatibility.
    xlabel, ylabel_freq, ylabel_amt : str
        Figure-level axis labels.
    legend_pos : str
        Legend location inside the legend panel [3,2].
    models : list
        Exactly 8 model names, one of which must be "MESONH", leaving the 7 models
        the layout has room for.
    fixed_transition : tuple
        (Ts_low, Ts_high), e.g., ("295", "305").
    comparison_mode : str
        Must be "model".
    errors_dict : dict
        Nested dict: errors_dict[model][transition_key] = float error (amount mode only).
    zoom_percentile : float
        Cumulative-probability threshold of the target distribution at which the
        zoom domain starts (e.g., 0.99). If it is never reached (all-zero target or
        a value above 1), the zoom domain starts at the last bin.
    save : bool
        Whether to save figure.
    save_path : str
        File path if save=True.

    Raises
    ------
    ValueError
        If `dist_type` is invalid, or `models` does not reduce to exactly 7 entries
        once "MESONH" is removed.
    """

    assert comparison_mode == "model", "Only 'model' comparison supported."
    assert fixed_transition is not None, "fixed_transition must be provided."
    assert len(models) == 8, "Must provide exactly 8 models for 4x4 layout."

    from matplotlib.ticker import ScalarFormatter
    import matplotlib.patches as patches

    Ts_low, Ts_high = fixed_transition
    key = f"{Ts_low}-{Ts_high}K"

    # --- Choose distribution type
    if dist_type == "amount":
        index = 1
    elif dist_type == "frequency":
        index = 0
    else:
        raise ValueError("dist_type must be 'amount' or 'frequency'")

    # --- The grid has exactly 7 main-panel slots; an 8th model would be drawn on
    # top of model 7's inline zoom panel at [2,3]. Validate before creating the figure.
    models_to_plot = [model for model in models if model != "MESONH"]
    if len(models_to_plot) != 7:
        raise ValueError(
            f"Expected 7 models after removing 'MESONH', got {len(models_to_plot)}: {models_to_plot}"
        )

    display_names = {"UKMOi-vn11.0-CASIM": "UKMO", "WRF_COL_CRM": "WRF"}

    # --- Create 4x4 figure with LaTeX-appropriate size
    fig, axs = plt.subplots(4, 4, figsize=(7.5, 5.5), sharey=False)

    bin_centers = 0.5 * (boundaries[:-1] + boundaries[1:])

    # --- Connector information, drawn after all panels (and their limits) exist
    connector_info = []

    for idx, model in enumerate(models_to_plot):
        low = dist_dic[model][f"{Ts_low}"][index]    # Low temperature baseline
        high = dist_dic[model][f"{Ts_high}"][index]  # High temperature target

        # Scaling parameters from the DataFrame
        b = df.loc[df["model"] == model, f"dellnqv_{key}"].values[0]  # Shift parameter
        a = df.loc[df["model"] == model, f"dellnP_{key}"].values[0]   # Amount scaling
        try:
            d_m = df.loc[df["model"] == model, f"delta_{key}"].values[0]  # Frequency scaling
        except KeyError:
            d_m = b - a  # Fallback when no delta column exists

        # Shift + interpolate the low distribution
        r_shift = bin_centers * np.exp(b)
        low_shifted = np.interp(bin_centers, r_shift, low, left=0, right=0)
        if dist_type == "amount":
            pred = np.exp(a) * low_shifted
        else:
            pred = np.exp(-d_m) * low_shifted

        # --- Zoom domain: first bin where the target's cumulative fraction reaches
        # zoom_percentile; fall back to the last bin if that never happens.
        cumsum = np.cumsum(high)
        total = cumsum[-1]
        reached = np.nonzero(cumsum / total >= zoom_percentile)[0] if total > 0 else np.array([], dtype=int)
        start_bin = reached[0] if reached.size else len(bin_centers) - 1
        zoom_domain_min = bin_centers[start_bin]
        zoom_domain_max = max(selected_labels)

        # --- Panel positions and connector style
        if idx < 4:
            main_row, main_col = 1, idx
            zoom_row, zoom_col = 0, idx
            connector_direction = 'top_row'
        else:
            main_row, main_col = 2, idx - 4
            if idx < 6:  # Models 5 and 6
                zoom_row, zoom_col = 3, idx - 4
                connector_direction = 'bottom_row'
            else:  # Model 7 (inline case)
                zoom_row, zoom_col = 2, 3
                connector_direction = 'odd_ball'

        # --- Main distribution panel
        ax_main = axs[main_row, main_col]
        ax_main.stairs(high, boundaries, edgecolor='k', linewidth=1.0, label=r'Target-$p_p$')
        ax_main.stairs(pred, boundaries, edgecolor="coral", linewidth=1.0, label=r'Predicted-$p_{sd}$')
        ax_main.stairs(low, boundaries, edgecolor='grey', linewidth=0.7, linestyle=':',
                       label=rf"{Ts_low} K")

        # --- Sliced arrays for the zoom region (values and edges share the start index)
        idx0 = int(np.searchsorted(bin_centers, zoom_domain_min, side="left"))
        idx0 = min(max(idx0, 0), len(bin_centers) - 1)
        pred_zoom = pred[idx0:]
        high_zoom = high[idx0:]
        low_zoom = low[idx0:]
        edges_zoom = boundaries[idx0:]

        # --- Rectangle/zoom height: 5% above the larger of target and prediction
        if pred_zoom.size > 0 and high_zoom.size > 0:
            y_max_zoom = 1.05 * max(float(np.max(pred_zoom)), float(np.max(high_zoom)))
            y_max_zoom = y_max_zoom if np.isfinite(y_max_zoom) and y_max_zoom > 0 else 1.0
        else:
            y_max_zoom = 1.0
        y_min_zoom = 0  # Rectangle is grounded at the x-axis

        rect = patches.Rectangle(
            (zoom_domain_min, y_min_zoom),
            zoom_domain_max - zoom_domain_min,
            y_max_zoom - y_min_zoom,
            linewidth=1.5, edgecolor='gray', facecolor='none',
            linestyle='--', alpha=0.7
        )
        ax_main.add_patch(rect)

        ax_main.set_xscale("log")
        ax_main.set_xticks(selected_labels)
        ax_main.set_xlim(min(selected_labels), max(selected_labels))

        # Only panel [2,2] shows full x tick labels; others keep tick marks only
        if main_row == 2 and main_col == 2:
            ax_main.set_xticklabels(selected_labels, rotation=45, ha="right")
        else:
            ax_main.set_xticklabels([])

        display_name = display_names.get(model, model)

        if dist_type == "amount":
            ax_main.text(0.05, 0.97, display_name, transform=ax_main.transAxes,
                         ha="left", va="top", color="gray", fontsize=8)
        else:
            name_y = 0.05 if model == "WRF_COL_CRM" else 0.08
            ax_main.text(0.05, name_y, display_name, transform=ax_main.transAxes,
                         ha="left", va="bottom", fontsize=8, fontweight="bold", color="gray")

        if dist_type == "amount":
            ax_main.text(0.01, 0.61, rf"$b$: {b:.2f}", transform=ax_main.transAxes, fontsize=7)
            ax_main.text(0.01, 0.49, rf"$d$: {d_m:.2f}", transform=ax_main.transAxes, fontsize=7)
            if errors_dict and model in errors_dict and key in errors_dict[model]:
                err_val = errors_dict[model][key]
                ax_main.text(0.01, 0.37, f"Error: {err_val*100:.0f}%", transform=ax_main.transAxes, fontsize=7)
        else:  # frequency mode: y-limit from the visible x-range of all three curves
            mask = (bin_centers >= min(selected_labels)) & (bin_centers <= max(selected_labels))
            ymax = 1.05 * max(np.max(high[mask]), np.max(pred[mask]), np.max(low[mask]))
            ax_main.set_ylim(0, ymax)
            ax_main.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
            ax_main.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

        # --- Zoom panel: only the region inside the rectangle
        ax_zoom = axs[zoom_row, zoom_col]
        ax_zoom.stairs(high_zoom, edges_zoom, edgecolor='k', linewidth=1.0)
        ax_zoom.stairs(pred_zoom, edges_zoom, edgecolor="coral", linewidth=1.0)
        ax_zoom.stairs(low_zoom, edges_zoom, edgecolor='grey', linewidth=0.7, linestyle=':')

        # Same x-range as the rectangle
        ax_zoom.set_xlim(zoom_domain_min, zoom_domain_max)
        ax_zoom.set_xscale("log")

        zoom_show_xlabels = ((zoom_row == 0 and zoom_col in [0, 1, 2, 3]) or
                             (zoom_row == 2 and zoom_col == 3) or
                             (zoom_row == 3 and zoom_col in [0, 1]))
        if zoom_show_xlabels:
            ax_zoom.tick_params(axis='x', labelrotation=20, labelbottom=True)
        else:
            ax_zoom.tick_params(axis='x', labelbottom=False)

        # Same y-range as the rectangle
        ax_zoom.set_ylim(0, y_max_zoom)
        ax_zoom.tick_params(axis='both', which='major', labelsize=7)
        ax_zoom.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
        ax_zoom.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

        connector_info.append({
            'main_ax': ax_main,
            'zoom_ax': ax_zoom,
            'zoom_min': zoom_domain_min,
            'zoom_max': zoom_domain_max,
            'y_min_zoom': y_min_zoom,
            'y_max_zoom': y_max_zoom,
            'direction': connector_direction
        })

    # --- Connector lines between rectangles and zoom panels
    for conn in connector_info:
        main_ax = conn['main_ax']
        zoom_ax = conn['zoom_ax']
        zoom_min = conn['zoom_min']
        y_max_zoom = conn['y_max_zoom']
        zoom_xlim = zoom_ax.get_xlim()
        zoom_ylim = zoom_ax.get_ylim()

        if conn['direction'] in ['top_row', 'bottom_row']:
            # Single connector: top-left of rectangle to top-left of zoom panel
            point_pairs = [((zoom_min, y_max_zoom), (zoom_xlim[0], zoom_ylim[1]))]
        else:
            # Inline case: connect both top corners
            point_pairs = [
                ((zoom_min, y_max_zoom), (zoom_xlim[0], zoom_ylim[1])),
                ((conn['zoom_max'], y_max_zoom), (zoom_xlim[1], zoom_ylim[1])),
            ]

        for main_point, zoom_point in point_pairs:
            line = patches.ConnectionPatch(
                xyA=main_point, xyB=zoom_point,
                coordsA="data", coordsB="data",
                axesA=main_ax, axesB=zoom_ax,
                color="gray", linewidth=0.8, linestyle="--", alpha=0.6
            )
            fig.add_artist(line)

    # --- Panel labels A..N, skipping the legend/empty panels in row 3
    panel_labels = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P']
    panel_idx = 0
    for row in range(4):
        for col in range(4):
            if row == 3 and col in (2, 3):
                continue
            axs[row][col].text(0.95, 0.95, panel_labels[panel_idx], transform=axs[row][col].transAxes,
                               ha='right', va='top', fontsize=12)
            panel_idx += 1

    # --- Legend panel [3,2], built from the first main panel's handles
    ax_legend = axs[3, 2]
    ax_legend.set_axis_off()
    handles, labels = axs[1, 0].get_legend_handles_labels()
    ax_legend.legend(handles, labels, loc=legend_pos, frameon=False, fontsize=8)

    # Panel [3,3] remains empty
    axs[3, 3].set_axis_off()

    # --- Shared axis labels
    fig.text(0.65, 0.04, xlabel, ha="center", fontsize=10)
    ylabel = ylabel_amt if dist_type == "amount" else ylabel_freq
    fig.text(0.02, 0.5, ylabel, va="center", rotation="vertical", fontsize=10)

    fig.suptitle(
        f"RCEMIP: Precipitation {dist_type.capitalize()} Distributions. "
        f"{Ts_low}-{Ts_high}K (Zoom ≥ {zoom_percentile*100:.0f}th percentile)",
        y=0.95, fontsize=11
    )

    plt.subplots_adjust(left=0.08, right=0.98, top=0.90, bottom=0.08,
                        wspace=0.20, hspace=0.30)

    if save and save_path:
        plt.savefig(save_path, dpi=600, bbox_inches="tight", pad_inches=0.02)

    plt.show()



    
